{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notebook written by [Zhedong Zheng](https://github.com/zhedongzheng)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import sklearn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "VOCAB_SIZE = 20000\n",
    "EMBED_DIM = 50\n",
    "BATCH_SIZE = 32\n",
    "LR = {'start': 5e-3, 'end': 5e-4, 'steps': 1500}\n",
    "N_EPOCH = 2\n",
    "N_CLASS = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sort_by_len(x, y):\n",
    "    idx = sorted(range(len(x)), key=lambda i: len(x[i]))\n",
    "    return x[idx], y[idx]\n",
    "\n",
    "def pad_sentence_batch(sent_batch):\n",
    "    max_seq_len = max([len(sent) for sent in sent_batch])\n",
    "    padded_seqs = [(sent + [0]*(max_seq_len - len(sent))) for sent in sent_batch]\n",
    "    return padded_seqs\n",
    "\n",
    "def next_train_batch(X_train, y_train):\n",
    "    for i in range(0, len(X_train), BATCH_SIZE):\n",
    "        padded_seqs = pad_sentence_batch(X_train[i : i+BATCH_SIZE])\n",
    "        yield padded_seqs, y_train[i : i+BATCH_SIZE]\n",
    "        \n",
    "def next_test_batch(X_test):\n",
    "    for i in range(0, len(X_test), BATCH_SIZE):\n",
    "        padded_seqs = pad_sentence_batch(X_test[i : i+BATCH_SIZE])\n",
    "        yield padded_seqs\n",
    "        \n",
    "def train_input_fn(X_train, y_train):\n",
    "    dataset = tf.data.Dataset.from_generator(\n",
    "        lambda: next_train_batch(X_train, y_train),\n",
    "        (tf.int32, tf.int64),\n",
    "        (tf.TensorShape([None,None]), tf.TensorShape([None])))\n",
    "    iterator = dataset.make_one_shot_iterator()\n",
    "    return iterator.get_next()\n",
    "\n",
    "def predict_input_fn(X_test):\n",
    "    dataset = tf.data.Dataset.from_generator(\n",
    "        lambda: next_test_batch(X_test),\n",
    "        tf.int32,\n",
    "        tf.TensorShape([None,None]))\n",
    "    iterator = dataset.make_one_shot_iterator()\n",
    "    return iterator.get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward(inputs, mode):\n",
    "    is_training = (mode == tf.estimator.ModeKeys.TRAIN)\n",
    "    masks = tf.abs(tf.sign(inputs))\n",
    "    \n",
    "    x = tf.contrib.layers.embed_sequence(inputs, VOCAB_SIZE, EMBED_DIM)\n",
    "    \n",
    "    # alignment\n",
    "    align = tf.layers.conv1d(x, filters=1, kernel_size=3, padding='same')\n",
    "    align = tf.squeeze(align, -1)\n",
    "    \n",
    "    # masking\n",
    "    paddings = tf.fill(tf.shape(align), float('-inf'))\n",
    "    align = tf.where(tf.equal(masks, 0), paddings, align)\n",
    "    \n",
    "    # probability\n",
    "    align = tf.expand_dims(tf.nn.softmax(align), -1)\n",
    "    \n",
    "    # weighted sum\n",
    "    x = tf.squeeze(tf.matmul(x, align, transpose_a=True), -1)\n",
    "    \n",
    "    logits = tf.layers.dense(x, N_CLASS)\n",
    "    return logits\n",
    "\n",
    "\n",
    "def model_fn(features, labels, mode):\n",
    "    logits = forward(features, mode)\n",
    "    \n",
    "    if mode == tf.estimator.ModeKeys.PREDICT:\n",
    "        preds = tf.argmax(logits, -1)\n",
    "        return tf.estimator.EstimatorSpec(mode, predictions=preds)\n",
    "    \n",
    "    if mode == tf.estimator.ModeKeys.TRAIN:\n",
    "        global_step = tf.train.get_global_step()\n",
    "\n",
    "        lr_op = tf.train.exponential_decay(\n",
    "            LR['start'], global_step, LR['steps'], LR['end']/LR['start'])\n",
    "\n",
    "        loss_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "            logits=logits, labels=labels))\n",
    "\n",
    "        train_op = tf.train.AdamOptimizer(lr_op).minimize(\n",
    "            loss_op, global_step=global_step)\n",
    "\n",
    "        lth = tf.train.LoggingTensorHook({'lr': lr_op}, every_n_iter=100)\n",
    "        \n",
    "        return tf.estimator.EstimatorSpec(\n",
    "            mode=mode, loss=loss_op, train_op=train_op, training_hooks=[lth])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using default config.\n",
      "WARNING:tensorflow:Using temporary folder as model directory: /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpc7vqekso\n",
      "INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpc7vqekso', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x11423f940>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 1 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpc7vqekso/model.ckpt.\n",
      "INFO:tensorflow:loss = 0.69318783, step = 1\n",
      "INFO:tensorflow:lr = 0.005\n",
      "INFO:tensorflow:global_step/sec: 170.4\n",
      "INFO:tensorflow:loss = 0.360905, step = 101 (0.588 sec)\n",
      "INFO:tensorflow:lr = 0.004288479 (0.588 sec)\n",
      "INFO:tensorflow:global_step/sec: 161.033\n",
      "INFO:tensorflow:loss = 0.23612434, step = 201 (0.621 sec)\n",
      "INFO:tensorflow:lr = 0.0036782112 (0.621 sec)\n",
      "INFO:tensorflow:global_step/sec: 145.414\n",
      "INFO:tensorflow:loss = 0.26872513, step = 301 (0.688 sec)\n",
      "INFO:tensorflow:lr = 0.0031547868 (0.688 sec)\n",
      "INFO:tensorflow:global_step/sec: 139.631\n",
      "INFO:tensorflow:loss = 0.51822644, step = 401 (0.716 sec)\n",
      "INFO:tensorflow:lr = 0.0027058476 (0.716 sec)\n",
      "INFO:tensorflow:global_step/sec: 116.07\n",
      "INFO:tensorflow:loss = 0.28991237, step = 501 (0.862 sec)\n",
      "INFO:tensorflow:lr = 0.0023207942 (0.862 sec)\n",
      "INFO:tensorflow:global_step/sec: 88.7481\n",
      "INFO:tensorflow:loss = 0.22387025, step = 601 (1.127 sec)\n",
      "INFO:tensorflow:lr = 0.0019905358 (1.127 sec)\n",
      "INFO:tensorflow:global_step/sec: 68.3388\n",
      "INFO:tensorflow:loss = 0.26271203, step = 701 (1.463 sec)\n",
      "INFO:tensorflow:lr = 0.0017072745 (1.463 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 782 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpc7vqekso/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 0.1374898.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpc7vqekso/model.ckpt-782\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Validation Accuracy: 0.9052\n",
      "\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpc7vqekso/model.ckpt-782\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 783 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpc7vqekso/model.ckpt.\n",
      "INFO:tensorflow:loss = 0.3836119, step = 783\n",
      "INFO:tensorflow:lr = 0.0015053472\n",
      "INFO:tensorflow:global_step/sec: 152.83\n",
      "INFO:tensorflow:loss = 0.06968155, step = 883 (0.656 sec)\n",
      "INFO:tensorflow:lr = 0.00129113 (0.656 sec)\n",
      "INFO:tensorflow:global_step/sec: 129.398\n",
      "INFO:tensorflow:loss = 0.099899, step = 983 (0.772 sec)\n",
      "INFO:tensorflow:lr = 0.001107397 (0.772 sec)\n",
      "INFO:tensorflow:global_step/sec: 141.331\n",
      "INFO:tensorflow:loss = 0.1479346, step = 1083 (0.708 sec)\n",
      "INFO:tensorflow:lr = 0.00094980985 (0.708 sec)\n",
      "INFO:tensorflow:global_step/sec: 118.511\n",
      "INFO:tensorflow:loss = 0.32019585, step = 1183 (0.844 sec)\n",
      "INFO:tensorflow:lr = 0.00081464805 (0.844 sec)\n",
      "INFO:tensorflow:global_step/sec: 111.124\n",
      "INFO:tensorflow:loss = 0.12379293, step = 1283 (0.900 sec)\n",
      "INFO:tensorflow:lr = 0.0006987203 (0.900 sec)\n",
      "INFO:tensorflow:global_step/sec: 86.815\n",
      "INFO:tensorflow:loss = 0.12812454, step = 1383 (1.152 sec)\n",
      "INFO:tensorflow:lr = 0.00059928955 (1.152 sec)\n",
      "INFO:tensorflow:global_step/sec: 65.4081\n",
      "INFO:tensorflow:loss = 0.17977262, step = 1483 (1.529 sec)\n",
      "INFO:tensorflow:lr = 0.00051400816 (1.529 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 1564 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpc7vqekso/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 0.04706523.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpc7vqekso/model.ckpt-1564\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Validation Accuracy: 0.9028\n",
      "\n"
     ]
    }
   ],
   "source": [
    "(X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=VOCAB_SIZE)\n",
    "X_train, y_train = sort_by_len(X_train, y_train)\n",
    "X_test, y_test = sort_by_len(X_test, y_test)\n",
    "\n",
    "estimator = tf.estimator.Estimator(model_fn)\n",
    "\n",
    "for _ in range(N_EPOCH):\n",
    "    estimator.train(lambda: train_input_fn(X_train, y_train))\n",
    "    y_pred = np.fromiter(estimator.predict(lambda: predict_input_fn(X_test)), np.int32)\n",
    "    print(\"\\nValidation Accuracy: %.4f\\n\" % (y_pred==y_test).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
